# SkyPilot task: setup clones pytorch-distributed-resnet and downloads
# CIFAR-10; run launches resnet_ddp.py through torch.distributed.launch.
name: resnet-distributed-app


resources:
    # One V100 per node; `run` starts one process per node
    # (--nproc_per_node=1).
    accelerators: V100:1

# Node count for this task. `run` does not hard-code it: --nnodes is computed
# from the number of lines in SKYPILOT_NODE_IPS.
num_nodes: 2

setup: |
    # Abort on the first failing command so a broken clone or install is
    # reported instead of being masked by later commands succeeding.
    set -e

    pip3 install --upgrade pip

    # Clone only when the checkout is missing, so re-running setup on a
    # machine that already has it does not hit "destination path already
    # exists".
    [ -d pytorch-distributed-resnet/.git ] || \
        git clone https://github.com/michaelzhiluo/pytorch-distributed-resnet
    cd pytorch-distributed-resnet

    # SkyPilot's default image on AWS/GCP has CUDA 11.6 (Azure 11.5).
    # torch is pinned to the cu113 build from the PyTorch wheel index.
    pip3 install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

    # Fetch and unpack CIFAR-10 inside ./data. Each step is its own command
    # (not an && chain) so that `set -e` stops at whichever one fails.
    # `wget -c` continues a partially downloaded archive.
    mkdir -p data saved_models
    cd data
    wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
    tar -xvzf cifar-10-python.tar.gz

run: |
    cd pytorch-distributed-resnet

    # SKYPILOT_NODE_IPS is read as one IP per line: the line count becomes
    # --nnodes and the first line becomes --master_addr.
    node_count=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
    head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

    # One launcher process per node; this node's rank comes from
    # SKYPILOT_NODE_RANK.
    python3 -m torch.distributed.launch \
        --nproc_per_node=1 \
        --nnodes=$node_count \
        --node_rank=${SKYPILOT_NODE_RANK} \
        --master_addr=$head_ip \
        --master_port=8008 \
        resnet_ddp.py --num_epochs 20
